#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from base64 import b64encode
from hashlib import sha1
import os
import socket
import ssl

from ambari_ws4py import WS_KEY, WS_VERSION
from ambari_ws4py.exc import HandshakeError
from ambari_ws4py.websocket import WebSocket
from ambari_ws4py.compat import urlsplit

# Public names exported by a star-import of this module.
__all__ = ["WebSocketBaseClient"]


class WebSocketBaseClient(WebSocket):
  def __init__(
    self,
    url,
    protocols=None,
    extensions=None,
    heartbeat_freq=None,
    ssl_options=None,
    headers=None,
    exclude_headers=None,
  ):
    """
    A websocket client that implements :rfc:`6455` and provides a simple
    interface to communicate with a websocket server.

    This class works on its own but will block if not run in
    its own thread.

    When an instance of this class is created, a :py:mod:`socket`
    is created. If the connection is a TCP socket,
    Nagle's algorithm is disabled.

    The address of the server will be extracted from the given
    websocket url.

    The websocket key is randomly generated; reset the
    `key` attribute if you want to provide your own.

    For instance to create a TCP client:

    .. code-block:: python

       >>> from ambari_ws4py.client import WebSocketBaseClient
       >>> ws = WebSocketBaseClient('ws://localhost/ws')


    Here is an example for a TCP client over SSL:

    .. code-block:: python

       >>> from ambari_ws4py.client import WebSocketBaseClient
       >>> ws = WebSocketBaseClient('wss://localhost/ws')


    Finally an example of a Unix-domain connection:

    .. code-block:: python

       >>> from ambari_ws4py.client import WebSocketBaseClient
       >>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')

    Note that in this case, the initial Upgrade request
    will be sent to ``/``. You may need to change this
    by setting the resource explicitly before connecting:

    .. code-block:: python

       >>> from ambari_ws4py.client import WebSocketBaseClient
       >>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')
       >>> ws.resource = '/ws'
       >>> ws.connect()

    You may provide extra headers by passing a list of tuples
    which must be unicode objects.

    :param url: a ``ws://``, ``wss://``, ``ws+unix://`` or ``wss+unix://`` URL.
    :param protocols: forwarded to :class:`WebSocket`; ``self.protocols`` is
      advertised in the ``Sec-WebSocket-Protocol`` handshake header.
    :param extensions: forwarded to :class:`WebSocket`.
    :param heartbeat_freq: forwarded to :class:`WebSocket`.
    :param ssl_options: keyword arguments in the style of the legacy
      :func:`ssl.wrap_socket`, used for ``wss`` URLs. The mapping is copied,
      and ``cert_reqs`` defaults to :data:`ssl.CERT_NONE` for ``wss``.
    :param headers: extra ``(name, value)`` tuples sent with the handshake.
    :param exclude_headers: header names (case-insensitive) removed from the
      handshake request.
    :raises ValueError: if the URL has an invalid scheme or hostname.
    """
    self.url = url
    self.host = None
    self.scheme = None
    self.port = None
    self.unix_socket_path = None
    self.resource = None
    # Copy so the defaults filled in below never leak into the caller's dict.
    self.ssl_options = dict(ssl_options) if ssl_options else {}
    self.extra_headers = headers or []
    self.exclude_headers = [x.lower() for x in (exclude_headers or [])]

    # The URL must be parsed first: the scheme check below depends on it.
    self._parse_url()

    if self.scheme == "wss" and "cert_reqs" not in self.ssl_options:
      # Prevent check_hostname requires server_hostname (ref #187)
      self.ssl_options["cert_reqs"] = ssl.CERT_NONE

    sock = self._create_socket()

    WebSocket.__init__(
      self,
      sock,
      protocols=protocols,
      extensions=extensions,
      heartbeat_freq=heartbeat_freq,
    )

    # Client-to-server frames are masked; server frames are not expected to be.
    self.stream.always_mask = True
    self.stream.expect_masking = False
    self.key = b64encode(os.urandom(16))

  def _create_socket(self):
    """
    Create the not-yet-connected socket for the parsed URL.

    A Unix-domain stream socket is returned when ``unix_socket_path`` is set;
    otherwise a TCP socket with ``TCP_NODELAY`` and ``SO_REUSEADDR`` enabled.
    """
    if self.unix_socket_path:
      return socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)

    # Let's handle IPv4 and IPv6 addresses
    # Simplified from CherryPy's code
    try:
      family, socktype, proto, _, _ = socket.getaddrinfo(
        self.host,
        self.port,
        socket.AF_UNSPEC,
        socket.SOCK_STREAM,
        0,
        socket.AI_PASSIVE,
      )[0]
    except socket.gaierror:
      # Resolution failed: guess the address family from the host literal.
      family = socket.AF_INET6 if self.host.startswith("::") else socket.AF_INET
      socktype = socket.SOCK_STREAM
      proto = 0

    sock = socket.socket(family, socktype, proto)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    if (
      hasattr(socket, "AF_INET6")
      and family == socket.AF_INET6
      and self.host.startswith("::")
    ):
      try:
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
      except (AttributeError, socket.error):
        # Best effort: not every platform exposes/permits IPV6_V6ONLY.
        pass
    return sock

  # Adapted from: https://github.com/liris/websocket-client/blob/master/websocket.py#L105
  def _parse_url(self):
    """
    Parses a URL which must have one of the following forms:

    - ws://host[:port][path]
    - wss://host[:port][path]
    - ws+unix:///path/to/my.socket

    In the first two cases, the ``host`` and ``port``
    attributes will be set to the parsed values. If no port
    is explicitly provided, it will be either 80 or 443
    based on the scheme. Also, the ``resource`` attribute is
    set to the path segment of the URL (alongside any querystring).

    In addition, if the scheme is ``ws+unix``, the
    ``unix_socket_path`` attribute is set to the path to
    the Unix socket while the ``resource`` attribute is
    set to ``/``.

    :raises ValueError: if the hostname is missing (non-unix schemes) or
      the scheme is not supported.
    """
    # Python 2.6.1 and below don't parse ws or wss urls properly. netloc is empty.
    # See: https://github.com/Lawouach/WebSocket-for-Python/issues/59
    scheme, url = self.url.split(":", 1)

    parsed = urlsplit(url, scheme="http")
    if parsed.hostname:
      self.host = parsed.hostname
    elif "+unix" in scheme:
      self.host = "localhost"
    else:
      raise ValueError(f"Invalid hostname from: {self.url}")

    if parsed.port:
      self.port = parsed.port

    if scheme == "ws":
      if not self.port:
        self.port = 80
    elif scheme == "wss":
      if not self.port:
        self.port = 443
    elif scheme in ("ws+unix", "wss+unix"):
      pass
    else:
      raise ValueError(f"Invalid scheme: {scheme}")

    if parsed.path:
      resource = parsed.path
    else:
      resource = "/"

    if "+unix" in scheme:
      # For unix sockets the URL path names the socket file, not the resource.
      self.unix_socket_path = resource
      resource = "/"

    if parsed.query:
      resource += "?" + parsed.query

    self.scheme = scheme
    self.resource = resource

  @property
  def bind_addr(self):
    """
    Returns the Unix socket path, or a tuple
    ``(host, port)``, depending on the initial
    URL's scheme.
    """
    return self.unix_socket_path or (self.host, self.port)

  def close(self, code=1000, reason=""):
    """
    Initiate the closing handshake with the server.

    Sends a masked close frame with ``code`` and ``reason`` at most once.
    """
    if not self.client_terminated:
      self.client_terminated = True
      self._write(self.stream.close(code=code, reason=reason).single(mask=True))

  @staticmethod
  def _wrap_ssl_socket(sock, ssl_options):
    """
    Wrap ``sock`` in TLS according to ``ssl_options``.

    ``ssl_options`` holds keyword arguments of the legacy
    :func:`ssl.wrap_socket`. That function was removed in Python 3.12; when
    it is unavailable an equivalent :class:`ssl.SSLContext` is built.

    :raises TypeError: if ``ssl_options`` contains an unsupported key.
    :raises ValueError: if ``keyfile`` is given without ``certfile``.
    """
    legacy_wrap = getattr(ssl, "wrap_socket", None)
    if legacy_wrap is not None:
      return legacy_wrap(sock, **ssl_options)

    options = dict(ssl_options)
    keyfile = options.pop("keyfile", None)
    certfile = options.pop("certfile", None)
    server_side = options.pop("server_side", False)
    cert_reqs = options.pop("cert_reqs", ssl.CERT_NONE)
    ssl_version = options.pop("ssl_version", None)
    ca_certs = options.pop("ca_certs", None)
    do_handshake_on_connect = options.pop("do_handshake_on_connect", True)
    suppress_ragged_eofs = options.pop("suppress_ragged_eofs", True)
    ciphers = options.pop("ciphers", None)
    if options:
      # ssl.wrap_socket rejected unknown keyword arguments with TypeError.
      raise TypeError(f"Unsupported ssl_options: {', '.join(sorted(options))}")
    if keyfile and not certfile:
      raise ValueError("certfile must be specified")

    if ssl_version is None:
      ssl_version = (
        ssl.PROTOCOL_TLS_SERVER if server_side else ssl.PROTOCOL_TLS_CLIENT
      )
    context = ssl.SSLContext(ssl_version)
    # ssl.wrap_socket never checked hostnames; keep that so ``cert_reqs``
    # alone decides certificate verification (must precede verify_mode).
    context.check_hostname = False
    context.verify_mode = cert_reqs
    if ca_certs:
      context.load_verify_locations(ca_certs)
    if certfile:
      context.load_cert_chain(certfile, keyfile)
    if ciphers:
      context.set_ciphers(ciphers)
    return context.wrap_socket(
      sock,
      server_side=server_side,
      do_handshake_on_connect=do_handshake_on_connect,
      suppress_ragged_eofs=suppress_ragged_eofs,
    )

  def connect(self):
    """
    Connects this websocket and starts the upgrade handshake
    with the remote endpoint.

    For ``wss`` URLs the socket is first wrapped in TLS using
    ``ssl_options``. The HTTP response is read until the end of its
    headers; any bytes received after them are passed to :meth:`process`.

    :raises HandshakeError: if the server sends nothing or an invalid
      upgrade response; the connection is closed before raising.
    """
    if self.scheme == "wss":
      # Upgrade the plain socket to TLS before connecting.
      self.sock = self._wrap_ssl_socket(self.sock, self.ssl_options)
      self._is_secure = True

    self.sock.settimeout(10.0)
    self.sock.connect(self.bind_addr)

    self._write(self.handshake_request)

    response = b""
    double_crlf = b"\r\n\r\n"
    # Read until the blank line ending the HTTP headers, or until EOF.
    while double_crlf not in response:
      chunk = self.sock.recv(128)
      if not chunk:
        break
      response += chunk

    if not response:
      self.close_connection()
      raise HandshakeError("Invalid response")

    headers, _, body = response.partition(double_crlf)
    response_line, _, headers = headers.partition(b"\r\n")

    try:
      self.process_response_line(response_line)
      self.protocols, self.extensions = self.process_handshake_header(headers)
    except HandshakeError:
      self.close_connection()
      raise

    self.handshake_ok()
    if body:
      # Data that arrived together with the headers is already websocket data.
      self.process(body)

  @staticmethod
  def _format_host(host):
    """Return ``host`` wrapped in brackets when it is an IPv6 literal (RFC 3986)."""
    if ":" in host and not host.startswith("["):
      return f"[{host}]"
    return host

  @property
  def handshake_headers(self):
    """
    List of headers appropriate for the upgrade
    handshake.

    An ``Origin`` header derived from the URL is added unless one was
    supplied or ``origin`` is excluded. Headers listed in
    ``exclude_headers`` are removed last. This property has no side effects.
    """
    host = self._format_host(self.host)
    # Unix-socket URLs have no port; don't emit "host:None".
    host_header = host if self.port is None else f"{host}:{self.port}"
    headers = [
      ("Host", host_header),
      ("Connection", "Upgrade"),
      ("Upgrade", "websocket"),
      ("Sec-WebSocket-Key", self.key.decode("utf-8")),
      ("Sec-WebSocket-Version", str(max(WS_VERSION))),
    ]

    if self.protocols:
      headers.append(("Sec-WebSocket-Protocol", ",".join(self.protocols)))

    if self.extra_headers:
      headers.extend(self.extra_headers)

    present = {header[0].lower() for header in headers}
    if "origin" not in present and "origin" not in self.exclude_headers:
      scheme, url = self.url.split(":", 1)
      parsed = urlsplit(url, scheme="http")
      origin_host = self._format_host(parsed.hostname or "localhost")
      origin = scheme + "://" + origin_host
      if parsed.port:
        origin = origin + ":" + str(parsed.port)
      headers.append(("Origin", origin))

    return [x for x in headers if x[0].lower() not in self.exclude_headers]

  @property
  def handshake_request(self):
    """
    Prepare the request to be sent for the upgrade handshake.

    Returns the UTF-8 encoded request line and headers, CRLF-separated and
    terminated by an empty line.
    """
    headers = self.handshake_headers
    request = [f"GET {self.resource} HTTP/1.1".encode("utf-8")]
    for header, value in headers:
      request.append(f"{header}: {value}".encode("utf-8"))
    request.append(b"\r\n")

    return b"\r\n".join(request)

  def process_response_line(self, response_line):
    """
    Ensure that we received a HTTP `101` status code in
    response to our request and if not raises :exc:`HandshakeError`.

    A status line missing the status code also raises
    :exc:`HandshakeError`. A missing reason phrase is tolerated.
    """
    parts = response_line.split(b" ", 2)
    if len(parts) < 2:
      raise HandshakeError(f"Invalid response line: {response_line}")
    code = parts[1]
    status = parts[2] if len(parts) > 2 else b""
    if code != b"101":
      raise HandshakeError(f"Invalid response status: {code} {status}")

  def process_handshake_header(self, headers):
    """
    Read the upgrade handshake's response headers and
    validate them against :rfc:`6455`.

    Header names and values are compared case-insensitively. The returned
    protocol and extension tokens are lower-cased ``bytes``.

    :returns: a ``(protocols, extensions)`` tuple of lists.
    :raises HandshakeError: on a malformed header line or an invalid
      ``Upgrade``, ``Connection`` or ``Sec-WebSocket-Accept`` value.
    """
    protocols = []
    extensions = []

    for header_line in headers.strip().split(b"\r\n"):
      if not header_line.strip():
        # A response without any header yields one empty line here.
        continue
      header, sep, value = header_line.partition(b":")
      if not sep:
        raise HandshakeError(f"Invalid header line: {header_line}")
      header = header.strip().lower()
      value = value.strip().lower()

      if header == b"upgrade" and value != b"websocket":
        raise HandshakeError(f"Invalid Upgrade header: {value}")

      elif header == b"connection":
        # RFC 6455 4.1: the header must *contain* the "Upgrade" token,
        # e.g. "keep-alive, Upgrade" is valid.
        tokens = [token.strip() for token in value.split(b",")]
        if b"upgrade" not in tokens:
          raise HandshakeError(f"Invalid Connection header: {value}")

      elif header == b"sec-websocket-accept":
        match = b64encode(sha1(self.key + WS_KEY).digest())
        if value != match.lower():
          raise HandshakeError(f"Invalid challenge response: {value}")

      elif header == b"sec-websocket-protocol":
        protocols.extend([x.strip() for x in value.split(b",")])

      elif header == b"sec-websocket-extensions":
        extensions.extend([x.strip() for x in value.split(b",")])

    return protocols, extensions

  def handshake_ok(self):
    """Hook run after a successful upgrade handshake; calls :meth:`opened`."""
    self.opened()
